{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QcJK3kXl--c3"
   },
   "source": [
    "# EECS 498-007/598-005 Assignment 1-2: K-Nearest Neighbors (k-NN)\n",
    "\n",
    "Before we start, please put your name and UMID in the following format\n",
    "\n",
    ": Firstname LASTNAME, #00000000   //   e.g.) Justin JOHNSON, #12345678"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "7sA2iBcm_cPb"
   },
   "source": [
    "**Your Answer:**   \n",
    "Your NAME, #XXXXXXXX"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Qc83ETI1a3o9"
   },
   "source": [
    "In this notebook you will implement a K-Nearest Neighbors classifier on the [CIFAR-10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html).\n",
    "\n",
    "Recall that the K-Nearest Neighbor classifier does the following:\n",
    "- During training, the classifier simply memorizes the training data\n",
    "- During testing, test images are compared to each training image; the predicted label is the majority vote among the K nearest training examples.\n",
    "\n",
    "After implementing the K-Nearest Neighbor classifier, you will use *cross-validation* to find the best value of K.\n",
    "\n",
    "The goals of this exercise are to go through a simple example of the data-driven image classification pipeline, and also to practice writing efficient, vectorized code in [PyTorch](https://pytorch.org/)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "hQrEwOpXb9Gh"
   },
   "source": [
    "# Setup Code\n",
    "Before getting started we need to run some boilerplate code to set up our environment. You'll need to rerun this setup code each time you start the notebook.\n",
    "\n",
    "First, run this cell to load the [autoreload](https://ipython.readthedocs.io/en/stable/config/extensions/autoreload.html?highlight=autoreload) extension. This allows us to edit `.py` source files and re-import them into the notebook for a seamless editing and debugging experience."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "73cuTs3re6wg"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Cnf0BfHZfWzO"
   },
   "source": [
    "### Google Colab Setup\n",
    "Next we need to run a few commands to set up our environment on Google Colab. If you are running this notebook on a local machine you can skip this section.\n",
    "\n",
    "Run the following cell to mount your Google Drive. Follow the link, sign in to your Google account (the same account you used to store this notebook!) and copy the authorization code into the text box that appears below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 122
    },
    "colab_type": "code",
    "id": "VxbQtNB6fWzO",
    "outputId": "d1d84d2e-beb3-4c5d-e50c-e8272eb6067d"
   },
   "outputs": [],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "IW2eBtZsfWzR"
   },
   "source": [
    "Now recall the path in your Google Drive where you uploaded this notebook and fill it in below. If everything is working correctly, running the following cell should print the filenames from the assignment:\n",
    "\n",
    "```\n",
    "['pytorch101.py', 'knn.py', 'knn.ipynb', 'eecs598', 'pytorch101.ipynb']\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 54
    },
    "colab_type": "code",
    "id": "kfzFXbiEfWzS",
    "outputId": "1dadb021-6bd0-4843-82f8-fd3201dd17aa"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# TODO: Fill in the Google Drive path where you uploaded the assignment\n",
    "# Example: If you create a 2020FA folder and put all the files under A1 folder, then '2020FA/A1'\n",
    "# GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = '2020FA/A1'\n",
    "GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = None\n",
    "GOOGLE_DRIVE_PATH = os.path.join('drive', 'My Drive', GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)\n",
    "print(os.listdir(GOOGLE_DRIVE_PATH))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aY_PV4eQfWzU"
   },
   "source": [
    "Once you have successfully mounted your Google Drive and located the path to this assignment, run the following cell to allow us to import from the `.py` files of this assignment. If it works correctly, it should print the message:\n",
    "\n",
    "```\n",
    "Hello from knn.py!\n",
    "```\n",
    "\n",
    "as well as the last edit time for the file `knn.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "VGbUf6nTfWzV",
    "outputId": "9f3c78b7-c5d6-451e-d6c7-a6a2c4b41235"
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(GOOGLE_DRIVE_PATH)\n",
    "\n",
    "import time, os\n",
    "os.environ[\"TZ\"] = \"US/Eastern\"\n",
    "time.tzset()\n",
    "\n",
    "from knn import hello\n",
    "hello()\n",
    "\n",
    "knn_path = os.path.join(GOOGLE_DRIVE_PATH, 'knn.py')\n",
    "knn_edit_time = time.ctime(os.path.getmtime(knn_path))\n",
    "print('knn.py last edited on %s' % knn_edit_time)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SWSgBT8Wf3tW"
   },
   "source": [
    "# Data preprocessing / Visualization"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "emQnvtnFeX1H"
   },
   "source": [
    "## Setup code\n",
    "Run some setup code for this notebook: Import some useful packages and increase the default figure size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Tf64a0TS8zh7"
   },
   "outputs": [],
   "source": [
    "import eecs598\n",
    "import torch\n",
    "import torchvision\n",
    "import matplotlib.pyplot as plt\n",
    "import statistics\n",
    "\n",
    "plt.rcParams['figure.figsize'] = (10.0, 8.0)\n",
    "plt.rcParams['font.size'] = 16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GSd6jQb4epkC"
   },
   "source": [
    "## Load the CIFAR-10 dataset\n",
    "The utility function `eecs598.data.cifar10()` returns the entire CIFAR-10 dataset as a set of four **Torch tensors**:\n",
    "\n",
    "- `x_train` contains all training images (real numbers in the range $[0, 1]$)\n",
    "- `y_train` contains all training labels (integers in the range $[0, 9]$)\n",
    "- `x_test` contains all test images\n",
    "- `y_test` contains all test labels\n",
    "\n",
    "This function automatically downloads the CIFAR-10 dataset the first time you run it."
   ]
  },
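  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The distance computations later in this notebook treat each image as one long vector. As a minimal sketch, assuming images are stored channels-first as `(N, 3, 32, 32)` (the shape printout in the next cell shows the actual layout) and using a random stand-in tensor instead of the real data, flattening every dimension except the first gives a matrix with one row per image:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Random stand-in for a batch of 4 images in (N, C, H, W) layout\n",
    "x = torch.rand(4, 3, 32, 32)\n",
    "\n",
    "# Flatten all dimensions except the batch dimension: (4, 3 * 32 * 32)\n",
    "x_flat = x.reshape(x.shape[0], -1)\n",
    "print(x_flat.shape)  # torch.Size([4, 3072])\n",
    "```\n",
    "\n",
    "Row `i` of `x_flat` holds exactly the pixel values of image `i`, so the squared Euclidean distance between two images is the same whether it is computed on the original tensors or on their flattened rows."
   ]
  },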
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 185,
     "referenced_widgets": [
      "5bbe1666cf604cd3ad400203c5e2c1d6",
      "43350d9be6d24650bb45d90d0108fba0",
      "491a89c56fa64153ac7cb50669421221",
      "759416a5a4ca48f78f37cced1c29bcba",
      "f70ec43cca1047eb8508ce8a0741997c",
      "ba451d02e663431586e490f97d21f780",
      "09cc8a49299948cfb7777d714b309cb6",
      "c575bd2f78744ac1a6ceb5b1d9d14776"
     ]
    },
    "colab_type": "code",
    "id": "y2JiLb-R9bFb",
    "outputId": "eabd0fe0-da2a-4eb5-9e56-5ef0a59e3521"
   },
   "outputs": [],
   "source": [
    "x_train, y_train, x_test, y_test = eecs598.data.cifar10()\n",
    "\n",
    "print('Training set:')\n",
    "print('  data shape:', x_train.shape)\n",
    "print('  labels shape:', y_train.shape)\n",
    "print('Test set:')\n",
    "print('  data shape:', x_test.shape)\n",
    "print('  labels shape:', y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AKKdLGIIffYx"
   },
   "source": [
    "## Visualize the dataset\n",
    "To give you a sense of the nature of the images in CIFAR-10, this cell visualizes some random examples from the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 483
    },
    "colab_type": "code",
    "id": "UMNVrzrd-d_y",
    "outputId": "1bd00712-3a50-4e4d-fb39-a9353c25c56e"
   },
   "outputs": [],
   "source": [
    "import random\n",
    "from torchvision.utils import make_grid\n",
    "\n",
    "classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']\n",
    "samples_per_class = 12\n",
    "samples = []\n",
    "for y, cls in enumerate(classes):\n",
    "    plt.text(-4, 34 * y + 18, cls, ha='right')\n",
    "    idxs, = (y_train == y).nonzero(as_tuple=True)\n",
    "    for i in range(samples_per_class):\n",
    "        idx = idxs[random.randrange(idxs.shape[0])].item()\n",
    "        samples.append(x_train[idx])\n",
    "img = make_grid(samples, nrow=samples_per_class)\n",
    "plt.imshow(eecs598.tensor_to_image(img))\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-nLyYUhBgDKp"
   },
   "source": [
    "## Subsample the dataset\n",
    "When implementing machine learning algorithms, it's usually a good idea to use a small sample of the full dataset. This way your code will run much faster, allowing for more interactive and efficient development. Once you are satisfied that you have correctly implemented the algorithm, you can then rerun with the entire dataset.\n",
    "\n",
    "The function `eecs598.data.cifar10()` can automatically subsample the CIFAR-10 dataset for us. To see how to use it, we can check the documentation using the built-in `help` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 323
    },
    "colab_type": "code",
    "id": "K5CYSO_ugyno",
    "outputId": "2d04f92c-92b8-4aa0-b1a8-f29765748353"
   },
   "outputs": [],
   "source": [
    "help(eecs598.data.cifar10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DtBIn0xjhPMd"
   },
   "source": [
    "We will subsample the data to use only 500 training examples and 250 test examples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "id": "FFmXwZbnG9ki",
    "outputId": "dc47bcd0-d46b-41d5-bc75-052ff7043d39"
   },
   "outputs": [],
   "source": [
    "num_train = 500\n",
    "num_test = 250\n",
    "\n",
    "x_train, y_train, x_test, y_test = eecs598.data.cifar10(num_train, num_test)\n",
    "\n",
    "print('Training set:')\n",
    "print('  data shape:', x_train.shape)\n",
    "print('  labels shape:', y_train.shape)\n",
    "print('Test set:')\n",
    "print('  data shape:', x_test.shape)\n",
    "print('  labels shape:', y_test.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-M0pmnWwgFu5"
   },
   "source": [
    "# K-Nearest Neighbors (k-NN)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NOZTkdiSmUFc"
   },
   "source": [
    "## Compute distances: Naive implementation\n",
    "Now that we have examined and prepared our data, it is time to implement the kNN classifier. We can break the process down into two steps:\n",
    "\n",
    "1. Compute the (squared Euclidean) distances between all training examples and all test examples\n",
    "2. Given these distances, for each test example find its k nearest neighbors and have them vote for the label to output\n",
    "\n",
    "Let's begin by computing the distance matrix between all training and test examples. First we will implement a naive version of the distance computation, using explicit loops over the training and test sets. In the file `knn.py`, implement the function `compute_distances_two_loops`.\n",
    "\n",
    "**NOTE: When implementing distance functions for this assignment, you may not use functions `torch.norm` or `torch.dist` (or their instance method variants `x.norm` / `x.dist`); you may not use any functions from `torch.nn` or `torch.nn.functional`.**"
   ]
  },
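  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a single (training, test) pair, each iteration of the naive version needs the sum of squared differences over all pixels. The sketch below computes that quantity for two random stand-in images using only elementwise arithmetic and `sum`, so it respects the restriction above, and cross-checks it against the dot product of the flattened difference with itself:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Two random stand-in images with the CIFAR-10 shape, in float64 for precision\n",
    "a = torch.rand(3, 32, 32, dtype=torch.float64)\n",
    "b = torch.rand(3, 32, 32, dtype=torch.float64)\n",
    "\n",
    "# Squared Euclidean distance: sum of squared elementwise differences\n",
    "d = ((a - b) ** 2).sum()\n",
    "\n",
    "# Same value from the flattened difference vector dotted with itself\n",
    "diff = (a - b).flatten()\n",
    "d_dot = diff.dot(diff)\n",
    "print(torch.allclose(d, d_dot))  # True\n",
    "```\n",
    "\n",
    "No square root is taken here: because the square root is monotonic, the nearest neighbors are the same under squared and unsquared Euclidean distance."
   ]
  },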
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "oHq2bs_MnqVM",
    "outputId": "29d3c420-e982-4a1b-f9c8-61dfe4bbe7b9"
   },
   "outputs": [],
   "source": [
    "from knn import compute_distances_two_loops\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_train = 500\n",
    "num_test = 250\n",
    "x_train, y_train, x_test, y_test = eecs598.data.cifar10(num_train, num_test)\n",
    "\n",
    "dists = compute_distances_two_loops(x_train, x_test)\n",
    "print('dists has shape: ', dists.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MGdFIqBEpPcQ"
   },
   "source": [
    "As a visual debugging step, we can visualize the distance matrix, where each row is a training example and each column is a test example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 512
    },
    "colab_type": "code",
    "id": "dshO3kmOKk0T",
    "outputId": "4ef94d74-8700-4dd8-f9f4-936fcf0e5125"
   },
   "outputs": [],
   "source": [
    "plt.imshow(dists.numpy(), cmap='gray', interpolation='none')\n",
    "plt.colorbar()\n",
    "plt.xlabel('test')\n",
    "plt.ylabel('train')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aHkuvdr_1HqC"
   },
   "source": [
    "## Compute distances: Vectorization\n",
    "Our implementation of the distance computation above is fairly inefficient since it uses nested Python loops over the training and test sets.\n",
    "\n",
    "When implementing algorithms in PyTorch, it's best to avoid loops in Python if possible. Instead it is preferable to implement your computation so that all loops happen inside PyTorch functions. This will usually be much faster than writing your own loops in Python, since PyTorch functions can be internally optimized to iterate efficiently, possibly using multiple threads. This is especially important when using a GPU to accelerate your code.\n",
    "\n",
    "The process of eliminating explicit loops from your code is called **vectorization**. Sometimes it is straightforward to vectorize code originally written with loops; other times vectorizing requires thinking about the problem in a new way. We will use vectorization to improve the speed of our distance computation function.\n",
    "\n",
    "As a first step toward vectorizing our distance computation, you will implement a version that uses only a single Python loop over the training data. In the file `knn.py`, complete the implementation of the function `compute_distances_one_loop`.\n",
    "\n",
    "We can check the correctness of our one-loop implementation by comparing it with our two-loop implementation on some randomly generated data.\n",
    "\n",
    "Note that we do the comparison with 64-bit floating-point numbers for increased numerical precision."
   ]
  },
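  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The main tool for removing the inner loop is broadcasting: subtracting one flattened training example from a whole matrix of flattened test examples produces all of the differences in a single expression. A small sketch on random stand-in vectors (the sizes 4 and 5 are arbitrary), checked against an explicit Python loop:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Stand-ins: 4 flattened test vectors of dimension 5, one training vector\n",
    "x_test = torch.randn(4, 5, dtype=torch.float64)\n",
    "x_one = torch.randn(5, dtype=torch.float64)\n",
    "\n",
    "# Broadcasting: x_one (shape (5,)) is virtually expanded to (4, 5),\n",
    "# so one expression gives the squared distance to every test row\n",
    "row = ((x_test - x_one) ** 2).sum(dim=1)  # shape (4,)\n",
    "\n",
    "# Reference: the same values computed with an explicit Python loop\n",
    "ref = torch.stack([((t - x_one) ** 2).sum() for t in x_test])\n",
    "print(row.shape, torch.allclose(row, ref))  # torch.Size([4]) True\n",
    "```"
   ]
  },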
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "ujU8bWch4EmK",
    "outputId": "d322edba-53d8-47a5-cd90-069ba57e0b12"
   },
   "outputs": [],
   "source": [
    "from knn import compute_distances_one_loop\n",
    "from knn import compute_distances_two_loops\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x_train_rand = torch.randn(100, 3, 16, 16, dtype=torch.float64)\n",
    "x_test_rand = torch.randn(100, 3, 16, 16, dtype=torch.float64)\n",
    "\n",
    "dists_one = compute_distances_one_loop(x_train_rand, x_test_rand)\n",
    "dists_two = compute_distances_two_loops(x_train_rand, x_test_rand)\n",
    "difference = (dists_one - dists_two).pow(2).sum().sqrt().item()\n",
    "print('Difference: ', difference)\n",
    "if difference < 1e-4:\n",
    "    print('Good! The distance matrices match')\n",
    "else:\n",
    "    print('Uh-oh! The distance matrices are different')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gqtIsY6x_kb9"
   },
   "source": [
    "You will now implement a fully vectorized version of the distance computation function\n",
    "that does not use any Python loops. In the file `knn.py`, implement the function `compute_distances_no_loops`.\n",
    "\n",
    "As before, we can check the correctness of our implementation by comparing the fully vectorized version against the original naive version:"
   ]
  },
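  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One common way to remove the last loop is to expand the squared distance as $\\|a - b\\|^2 = \\|a\\|^2 + \\|b\\|^2 - 2\\,a^\\top b$, so that the cross terms for every (training, test) pair come from a single matrix multiplication. This is only one possible approach; the sketch below uses random, already-flattened stand-ins (rows of the result index training examples, columns index test examples) and checks the result against a fully broadcast reference:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Stand-ins: 6 training and 4 test vectors of dimension 5\n",
    "x_train = torch.randn(6, 5, dtype=torch.float64)\n",
    "x_test = torch.randn(4, 5, dtype=torch.float64)\n",
    "\n",
    "# ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for all pairs at once\n",
    "train_sq = (x_train ** 2).sum(dim=1)  # (6,)\n",
    "test_sq = (x_test ** 2).sum(dim=1)    # (4,)\n",
    "cross = x_train @ x_test.t()          # (6, 4) matrix of dot products\n",
    "dists = train_sq[:, None] + test_sq[None, :] - 2 * cross  # (6, 4)\n",
    "\n",
    "# Reference via full broadcasting: (6, 1, 5) - (1, 4, 5) -> (6, 4, 5)\n",
    "ref = ((x_train[:, None, :] - x_test[None, :, :]) ** 2).sum(dim=2)\n",
    "print(dists.shape, torch.allclose(dists, ref))  # torch.Size([6, 4]) True\n",
    "```\n",
    "\n",
    "Both versions are loop-free, but the broadcast reference materializes a `(num_train, num_test, D)` tensor: for 5000 training and 500 test CIFAR-10 images with D = 3072 that is about 7.7 billion float32 values (roughly 30 GB), whereas the expansion only builds `(num_train, num_test)` matrices. Because of floating-point rounding, the expansion can return tiny negative values for nearly identical points; `dists.clamp(min=0)` is one way to remove them."
   ]
  },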
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "1RY8QBeS9WYK",
    "outputId": "b8ed1d8e-cd2f-4a84-864a-08dbd51c1698"
   },
   "outputs": [],
   "source": [
    "from knn import compute_distances_two_loops\n",
    "from knn import compute_distances_no_loops\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x_train_rand = torch.randn(100, 3, 16, 16, dtype=torch.float64)\n",
    "x_test_rand = torch.randn(100, 3, 16, 16, dtype=torch.float64)\n",
    "\n",
    "dists_two = compute_distances_two_loops(x_train_rand, x_test_rand)\n",
    "dists_none = compute_distances_no_loops(x_train_rand, x_test_rand)\n",
    "difference = (dists_two - dists_none).pow(2).sum().sqrt().item()\n",
    "print('Difference: ', difference)\n",
    "if difference < 1e-4:\n",
    "    print('Good! The distance matrices match')\n",
    "else:\n",
    "    print('Uh-oh! The distance matrices are different')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "0JPMM0-BBGmt"
   },
   "source": [
    "We can now compare the speed of our three implementations. If you've implemented everything properly, the one-loop implementation should take less than 4 seconds to run, and the fully vectorized implementation should take less than 0.1 seconds to run."
   ]
  },
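  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `timeit` helper in the next cell measures a single run with `time.time()`, so the reported times can vary between runs. If steadier numbers are wanted, one option is to take the best of several runs with `time.perf_counter()`; on a GPU, CUDA kernels are launched asynchronously, so the clock should only be read after `torch.cuda.synchronize()`. A sketch with a hypothetical `benchmark` helper (it is not part of `knn.py` or `eecs598`):\n",
    "\n",
    "```python\n",
    "import time\n",
    "import torch\n",
    "\n",
    "def benchmark(f, *args, repeats=3):\n",
    "    # Hypothetical helper: best-of-`repeats` wall-clock time of f(*args)\n",
    "    best = float('inf')\n",
    "    for _ in range(repeats):\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.synchronize()  # wait for earlier GPU work to finish\n",
    "        tic = time.perf_counter()\n",
    "        f(*args)\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.synchronize()  # wait for kernels launched by f\n",
    "        best = min(best, time.perf_counter() - tic)\n",
    "    return best\n",
    "\n",
    "x = torch.randn(500, 3072)\n",
    "print('matmul took %.4f seconds' % benchmark(torch.mm, x, x.t()))\n",
    "```"
   ]
  },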
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "IN9cntDC5c5q",
    "outputId": "b78643e0-ce71-41b1-ffed-6f3e625ceafe"
   },
   "outputs": [],
   "source": [
    "import time\n",
    "from knn import compute_distances_two_loops\n",
    "from knn import compute_distances_one_loop\n",
    "from knn import compute_distances_no_loops\n",
    "\n",
    "def timeit(f, *args):\n",
    "    tic = time.time()\n",
    "    f(*args) \n",
    "    toc = time.time()\n",
    "    return toc - tic\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x_train_rand = torch.randn(500, 3, 32, 32)\n",
    "x_test_rand = torch.randn(500, 3, 32, 32)\n",
    "\n",
    "two_loop_time = timeit(compute_distances_two_loops, x_train_rand, x_test_rand)\n",
    "print('Two loop version took %.2f seconds' % two_loop_time)\n",
    "\n",
    "one_loop_time = timeit(compute_distances_one_loop, x_train_rand, x_test_rand)\n",
    "speedup = two_loop_time / one_loop_time\n",
    "print('One loop version took %.2f seconds (%.1fX speedup)'\n",
    "      % (one_loop_time, speedup))\n",
    "\n",
    "no_loop_time = timeit(compute_distances_no_loops, x_train_rand, x_test_rand)\n",
    "speedup = two_loop_time / no_loop_time\n",
    "print('No loop version took %.2f seconds (%.1fX speedup)'\n",
    "      % (no_loop_time, speedup))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EudsSj5TrGGF"
   },
   "source": [
    "## Predict labels\n",
    "Now that we have a method for computing distances between training and test examples, we need to implement a function that uses those distances together with the training labels to predict labels for test samples.\n",
    "\n",
    "In the file `knn.py`, implement the function `predict_labels`."
   ]
  },
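  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Predicting labels has two parts: find the k smallest entries in each column of `dists` (one column per test example), then take a majority vote over the labels of those training examples. In the test below, the three nearest neighbors of the second test example have labels 0, 1 and 2, and the expected prediction is 0, which is consistent with breaking ties in favor of the smallest label. A minimal vectorized sketch under that tie-breaking assumption, using a hypothetical function name `majority_vote` (`torch.topk` with `largest=False` returns the k smallest values):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def majority_vote(dists, y_train, k):\n",
    "    # Hypothetical helper. dists: (num_train, num_test); y_train: (num_train,)\n",
    "    num_classes = int(y_train.max().item()) + 1\n",
    "    # Indices of the k smallest distances in each column: (k, num_test)\n",
    "    _, idx = dists.topk(k, dim=0, largest=False)\n",
    "    neighbor_labels = y_train[idx]  # (k, num_test)\n",
    "    # counts[c, j] = number of the k neighbors of test example j with label c\n",
    "    counts = torch.zeros(num_classes, dists.shape[1], dtype=torch.long)\n",
    "    counts.scatter_add_(0, neighbor_labels, torch.ones_like(neighbor_labels))\n",
    "    # Keep only classes that reach the maximum count, then pick the smallest\n",
    "    is_max = counts == counts.max(dim=0, keepdim=True).values\n",
    "    labels = torch.arange(num_classes).unsqueeze(1).expand_as(counts)\n",
    "    candidates = torch.where(is_max, labels, torch.full_like(labels, num_classes))\n",
    "    return candidates.min(dim=0).values\n",
    "\n",
    "dists = torch.tensor([\n",
    "    [0.3, 0.4, 0.1],\n",
    "    [0.1, 0.5, 0.5],\n",
    "    [0.4, 0.1, 0.2],\n",
    "    [0.2, 0.2, 0.4],\n",
    "    [0.5, 0.3, 0.3],\n",
    "])\n",
    "y_train = torch.tensor([0, 1, 0, 1, 2])\n",
    "print(majority_vote(dists, y_train, k=3))  # tensor([1, 0, 0])\n",
    "```\n",
    "\n",
    "The `scatter_add_` call builds a `(num_classes, num_test)` table of vote counts, and masking out every class that does not reach the maximum count makes the tie-breaking rule explicit rather than relying on how a library function orders equal values."
   ]
  },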
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "MWk4BTMKfWz8",
    "outputId": "e4aa005a-ab93-4232-b27c-c1c365246f7d"
   },
   "outputs": [],
   "source": [
    "from knn import predict_labels\n",
    "\n",
    "torch.manual_seed(0)\n",
    "dists = torch.tensor([\n",
    "    [0.3, 0.4, 0.1],\n",
    "    [0.1, 0.5, 0.5],\n",
    "    [0.4, 0.1, 0.2],\n",
    "    [0.2, 0.2, 0.4],\n",
    "    [0.5, 0.3, 0.3],\n",
    "])\n",
    "y_train = torch.tensor([0, 1, 0, 1, 2])\n",
    "y_pred_expected = torch.tensor([1, 0, 0])\n",
    "y_pred = predict_labels(dists, y_train, k=3)\n",
    "correct = y_pred.tolist() == y_pred_expected.tolist()\n",
    "print('Correct: ', correct)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fMBf1Z6VF9hx"
   },
   "source": [
    "Now we have implemented all the required functionality for the K-Nearest Neighbor classifier. In the file `knn.py`, complete the implementation of the `KnnClassifier` class.\n",
    "\n",
    "We can get some intuition into the KNN classifier by visualizing its predictions on toy 2D data. Here we will generate some random training and test points in 2D, and assign random labels to the training points. We can then make predictions for the test points, and visualize both training and test points. Training points are shown as stars, and test points are shown as small transparent circles. The color of each point denotes its label: the ground-truth label for training points, and the predicted label for test points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "zTa7xowOfWz-",
    "outputId": "742be2a6-8d9c-47f2-8773-84f558b4f42d"
   },
   "outputs": [],
   "source": [
    "from knn import KnnClassifier\n",
    "\n",
    "num_test = 10000\n",
    "num_train = 20\n",
    "num_classes = 5\n",
    "\n",
    "# Generate random training and test data\n",
    "torch.manual_seed(128)\n",
    "x_train = torch.rand(num_train, 2)\n",
    "y_train = torch.randint(num_classes, size=(num_train,))\n",
    "x_test = torch.rand(num_test, 2)\n",
    "classifier = KnnClassifier(x_train, y_train)\n",
    "\n",
    "# Plot predictions for different values of k\n",
    "for k in [1, 3, 5]:\n",
    "    y_test = classifier.predict(x_test, k=k)\n",
    "    plt.gcf().set_size_inches(8, 8)\n",
    "    class_colors = ['r', 'g', 'b', 'k', 'y']\n",
    "    train_colors = [class_colors[c] for c in y_train]\n",
    "    test_colors = [class_colors[c] for c in y_test]\n",
    "    plt.scatter(x_test[:, 0], x_test[:, 1],\n",
    "                color=test_colors, marker='o', s=32, alpha=0.05)\n",
    "    plt.scatter(x_train[:, 0], x_train[:, 1],\n",
    "                color=train_colors, marker='*', s=128.0)\n",
    "    plt.title('Predictions for k = %d' % k, size=16)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2tgNeDX0fW0A"
   },
   "source": [
    "We can use the exact same KNN code to perform image classification on CIFAR-10!\n",
    "\n",
    "Now let's put everything together and test our K-NN classifier on a subset of CIFAR-10, using k=1.\n",
    "\n",
    "If you've implemented everything correctly you should see an accuracy of about 27%."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "W5GVNBh0ySGN",
    "outputId": "32a4ed21-8d39-4c03-ed72-59a6e963da8e"
   },
   "outputs": [],
   "source": [
    "from knn import KnnClassifier\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_train = 5000\n",
    "num_test = 500\n",
    "x_train, y_train, x_test, y_test = eecs598.data.cifar10(num_train, num_test)\n",
    "\n",
    "classifier = KnnClassifier(x_train, y_train)\n",
    "classifier.check_accuracy(x_test, y_test, k=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QQwHpcPrIF5u"
   },
   "source": [
    "Now let's increase to k=5. You should see a slightly higher accuracy than with k=1:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "_a4zwcTe0PIK",
    "outputId": "f1d12b4a-8a70-4b8c-a6d4-8565036e57dd"
   },
   "outputs": [],
   "source": [
    "from knn import KnnClassifier\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_train = 5000\n",
    "num_test = 500\n",
    "x_train, y_train, x_test, y_test = eecs598.data.cifar10(num_train, num_test)\n",
    "\n",
    "classifier = KnnClassifier(x_train, y_train)\n",
    "classifier.check_accuracy(x_test, y_test, k=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QNyZLRmaIgT0"
   },
   "source": [
    "## Cross-validation\n",
    "We have now implemented the full k-Nearest Neighbor classifier, but the choice of $k=5$ was arbitrary. We will use **cross-validation** to set this hyperparameter in a more principled manner.\n",
    "\n",
    "In the file `knn.py`, implement the function `knn_cross_validate` to perform cross-validation on k."
   ]
  },
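  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In k-fold cross-validation, the training set is split into `num_folds` folds; for each candidate k, every fold in turn serves as the validation set while the classifier is trained on the remaining folds, which yields one accuracy per fold. The splitting step can be sketched with `torch.chunk` and `torch.cat`, here on tiny stand-in tensors:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "num_folds = 5\n",
    "# Stand-ins: 10 examples with 3 features each, plus their labels\n",
    "x = torch.arange(30, dtype=torch.float32).reshape(10, 3)\n",
    "y = torch.arange(10)\n",
    "\n",
    "# torch.chunk splits along dim 0 into (at most) num_folds pieces\n",
    "x_folds = torch.chunk(x, num_folds)\n",
    "y_folds = torch.chunk(y, num_folds)\n",
    "\n",
    "for i in range(num_folds):\n",
    "    # Fold i is held out for validation; the rest are concatenated for training\n",
    "    x_val, y_val = x_folds[i], y_folds[i]\n",
    "    x_tr = torch.cat(x_folds[:i] + x_folds[i + 1:])\n",
    "    y_tr = torch.cat(y_folds[:i] + y_folds[i + 1:])\n",
    "    print(i, tuple(x_tr.shape), tuple(x_val.shape), y_val.tolist())\n",
    "```\n",
    "\n",
    "Every iteration prints a training shape of `(8, 3)` and a validation shape of `(2, 3)`. Be aware that `torch.chunk` can return fewer chunks than requested when the size along the split dimension does not divide evenly; with 5000 examples and 5 folds, as in the cell below, the split is even."
   ]
  },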
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 272
    },
    "colab_type": "code",
    "id": "pA5MrumnLk5B",
    "outputId": "be8c8630-d020-4f79-f336-b76ce1f4bf6c"
   },
   "outputs": [],
   "source": [
    "from knn import knn_cross_validate\n",
    "\n",
    "torch.manual_seed(0)\n",
    "num_train = 5000\n",
    "num_test = 500\n",
    "x_train, y_train, x_test, y_test = eecs598.data.cifar10(num_train, num_test)\n",
    "\n",
    "k_to_accuracies = knn_cross_validate(x_train, y_train, num_folds=5)\n",
    "\n",
    "for k, accs in sorted(k_to_accuracies.items()):\n",
    "    print('k = %d got accuracies: %r' % (k, accs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 528
    },
    "colab_type": "code",
    "id": "vMtPikIsNxl2",
    "outputId": "0443bc91-e5ec-4fb9-e430-62806604609e"
   },
   "outputs": [],
   "source": [
    "ks, means, stds = [], [], []\n",
    "torch.manual_seed(0)\n",
    "for k, accs in sorted(k_to_accuracies.items()):\n",
    "    plt.scatter([k] * len(accs), accs, color='g')\n",
    "    ks.append(k)\n",
    "    means.append(statistics.mean(accs))\n",
    "    stds.append(statistics.stdev(accs))\n",
    "plt.errorbar(ks, means, yerr=stds)\n",
    "plt.xlabel('k')\n",
    "plt.ylabel('Cross-validation accuracy')\n",
    "plt.title('Cross-validation on k')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "XZ3Ue0bxmObU"
   },
   "source": [
    "Now we can use the results of cross-validation to select the best value for k, and rerun the classifier on the full set of 5000 training examples.\n",
    "\n",
    "You should get an accuracy above 28%."
   ]
  },
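  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Choosing k amounts to picking the key of `k_to_accuracies` with the highest mean accuracy across folds. A minimal sketch with clearly made-up accuracies (not real results); it assumes ties go to the smallest k, so check the docstring of `knn_get_best_k` in `knn.py` for the rule it actually requires:\n",
    "\n",
    "```python\n",
    "import statistics\n",
    "\n",
    "# Made-up accuracies for illustration only: k -> list of per-fold accuracies\n",
    "k_to_accuracies = {\n",
    "    1: [0.2, 0.4],\n",
    "    3: [0.5, 0.5],\n",
    "    5: [0.25, 0.75],\n",
    "    8: [0.3, 0.3],\n",
    "}\n",
    "\n",
    "# Highest mean accuracy wins; iterating over the keys in sorted order means\n",
    "# max() returns the smallest k among ties (an assumed tie-breaking rule)\n",
    "best_k = max(sorted(k_to_accuracies),\n",
    "             key=lambda k: statistics.mean(k_to_accuracies[k]))\n",
    "print(best_k)  # 3\n",
    "```\n",
    "\n",
    "Here k=3 and k=5 both have a mean accuracy of 0.5, and the sorted iteration order selects 3."
   ]
  },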
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "NBZfp1UtWyoG",
    "outputId": "b914f55f-070b-46b5-8bca-f84d7e82d88f"
   },
   "outputs": [],
   "source": [
    "from knn import KnnClassifier\n",
    "from knn import knn_get_best_k\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "best_k = knn_get_best_k(k_to_accuracies)\n",
    "print('Best k is ', best_k)\n",
    "\n",
    "classifier = KnnClassifier(x_train, y_train)\n",
    "classifier.check_accuracy(x_test, y_test, k=best_k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "R1LevOE5mYJh"
   },
   "source": [
    "Finally, we can use our chosen value of k to run on the entire training and test sets.\n",
    "\n",
    "This may take a while to run, since the full training and test sets have 50k and 10k examples respectively. You should get an accuracy above 33%.\n",
    "\n",
    "**Run this only once!**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "5gcXjsjFkcGV",
    "outputId": "a5d21480-6f1d-456c-8b54-d56ba26f0550"
   },
   "outputs": [],
   "source": [
    "from knn import KnnClassifier\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x_train_all, y_train_all, x_test_all, y_test_all = eecs598.data.cifar10()\n",
    "classifier = KnnClassifier(x_train_all, y_train_all)\n",
    "classifier.check_accuracy(x_test_all, y_test_all, k=best_k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_eeuN70qb1oy"
   },
   "source": [
    "## Submit Your Work\n",
    "After completing both notebooks for this assignment (`pytorch101.ipynb` and this notebook, `knn.ipynb`), run the following cell to create a `.zip` file for you to download and turn in:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "3kXg-9z8b1oz",
    "outputId": "072fb23b-b4ab-4942-ae58-1c89329d4829"
   },
   "outputs": [],
   "source": [
    "from eecs598.submit import make_a1_submission\n",
    "\n",
    "make_a1_submission(GOOGLE_DRIVE_PATH)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "hQrEwOpXb9Gh",
    "Cnf0BfHZfWzO",
    "SWSgBT8Wf3tW",
    "emQnvtnFeX1H",
    "GSd6jQb4epkC",
    "AKKdLGIIffYx",
    "-nLyYUhBgDKp",
    "-M0pmnWwgFu5",
    "NOZTkdiSmUFc",
    "aHkuvdr_1HqC",
    "EudsSj5TrGGF",
    "QNyZLRmaIgT0",
    "_eeuN70qb1oy"
   ],
   "name": "knn.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
